"""Script to output the correct installation command for cut-cross-entropy."""

import importlib.util
import sys

try:
    import torch
except ImportError as exc:
    raise ImportError("Install torch via `pip install torch`") from exc
from packaging.version import Version as V

USE_UV = "--uv" in sys.argv[1:]

v = V(torch.__version__)

# no cut-cross-entropy support for torch < 2.4.0
if v < V("2.4.0"):
    print("")
    sys.exit(0)

cce_spec = importlib.util.find_spec("cut_cross_entropy")

UNINSTALL_PREFIX = ""
if cce_spec:
    if not importlib.util.find_spec("cut_cross_entropy.transformers"):
        UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "

UV_PREFIX = "uv " if USE_UV else ""

print(
    UNINSTALL_PREFIX
    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@78b2a45713a54c9bedf8b33f5e31cf07a1a57154"'
)
